#Setup
import os
import time
import math

# Framework configuration lives in environment variables, so it is set before
# keras (and therefore its backend) is imported below. keras reads
# KERAS_BACKEND when it is imported.
os.environ["KERAS_BACKEND"] = "jax"  # @param ["tensorflow", "jax", "torch"]
# Expose only CUDA device index 0 (the first GPU). CUDA reads this variable
# when it initialises, which a backend may trigger as early as import time,
# so it must also be set before importing keras.
# NOTE(review): an earlier comment said "GPU 1"; confirm index 0 is intended.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

import keras
from keras import layers
from keras import ops
from keras.callbacks import Callback
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from src.standard_attention import StandardMultiHeadAttention
from src.optimised_attention import OptimisedAttention
from src.efficient_attention import EfficientAttention
from src.super_attention import SuperAttention

#Prepare the data
num_classes = 10
input_shape = (28, 28, 1)

(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()
# The raw image arrays have no channel axis; append a trailing one so every
# image matches input_shape (28, 28, 1).
x_train = x_train[..., np.newaxis]
x_test = x_test[..., np.newaxis]

print(f"x_train shape: {x_train.shape} - y_train shape: {y_train.shape}")
print(f"x_test shape: {x_test.shape} - y_test shape: {y_test.shape}")

#Configure the hyperparameters
# --- Training schedule ---
# Number of training epochs. Use a small value (e.g. 10) for a quick smoke
# test; 100 epochs was previously suggested for a full-length run.
num_epochs = 50
batch_size = 2048

# --- AdamW optimiser ---
# The learning rate is decayed once per epoch from max_learning_rate towards
# min_learning_rate (see lr_scheduler).
max_learning_rate = 0.001
min_learning_rate = 0.00002
weight_decay = 0.0001

# --- Vision Transformer geometry ---
image_size = 32  # Input images are resized to image_size x image_size
patch_size = 4  # Side length of the square patches extracted from each image
num_patches = (image_size // patch_size) ** 2
projection_dim = num_patches * 2  # Token embedding size

# Hidden widths of the MLP inside each Transformer block. The last width must
# equal projection_dim because the MLP output is added back to its input
# (residual connection in create_vit_classifier).
transformer_units = [
    projection_dim,
]
transformer_layers = 3  # Number of stacked Transformer blocks
# Hidden widths of the dense layers in the final classification head.
mlp_head_units = [
    96,
    48,
]


#Use data augmentation
# In-model preprocessing pipeline: standardise pixel values, upsample to
# image_size x image_size, then apply random rotation/zoom/translation and
# additive Gaussian noise.
normalizer = layers.Normalization()
data_augmentation = keras.Sequential(
    [
        normalizer,
        layers.Resizing(image_size, image_size),
        layers.RandomRotation(factor=0.05),
        layers.RandomZoom(height_factor=0.1, width_factor=0.1),
        layers.RandomTranslation(height_factor=0.1, width_factor=0.1),
        layers.GaussianNoise(0.05),
    ],
    name="data_augmentation",
)
# Fit the normaliser's mean and variance on the training images.
normalizer.adapt(x_train)

class TimeHistory(Callback):
    """Keras callback that records how long each training epoch takes.

    After ``fit`` returns, ``times`` holds one elapsed duration in seconds per
    completed epoch, in epoch order.
    """

    def on_train_begin(self, logs=None):
        # Reset on every fit() so timings from separate runs never mix.
        self.times = []

    def on_epoch_begin(self, epoch, logs=None):
        # perf_counter is monotonic, so system clock adjustments cannot
        # corrupt the measured durations (unlike time.time()).
        self.epoch_start_time = time.perf_counter()

    def on_epoch_end(self, epoch, logs=None):
        self.times.append(time.perf_counter() - self.epoch_start_time)

#Implement multilayer perceptron (MLP)
def mlp(x, hidden_units, dropout_rate):
    """Apply a stack of GELU Dense layers, each followed by Dropout.

    Args:
        x: Input tensor.
        hidden_units: Layer widths, applied in order (one Dense + Dropout
            pair per entry).
        dropout_rate: Dropout rate applied after every Dense layer.

    Returns:
        The output of the last Dropout layer, or ``x`` unchanged when
        ``hidden_units`` is empty.
    """
    tensor = x
    for width in hidden_units:
        dense_layer = layers.Dense(width, activation=keras.activations.gelu)
        tensor = dense_layer(tensor)
        dropout_layer = layers.Dropout(dropout_rate)
        tensor = dropout_layer(tensor)
    return tensor

#Implement patch creation as a layer
class Patches(layers.Layer):
    """Split a batch of images into flattened, non-overlapping square patches.

    Input: images of shape (batch, height, width, channels).
    Output: (batch, (height // patch_size) * (width // patch_size),
             patch_size * patch_size * channels).
    """

    def __init__(self, patch_size, **kwargs):
        # Forward base-layer kwargs (name, dtype, trainable, ...) so the layer
        # can be rebuilt with from_config(get_config()).
        super().__init__(**kwargs)
        self.patch_size = patch_size

    def call(self, images):
        shape = ops.shape(images)
        batch = shape[0]
        height, width, channels = shape[1], shape[2], shape[3]
        num_patches_h = height // self.patch_size
        num_patches_w = width // self.patch_size
        patches = ops.image.extract_patches(images, size=self.patch_size)
        # Flatten the 2-D grid of patches into a single token axis.
        return ops.reshape(
            patches,
            (
                batch,
                num_patches_h * num_patches_w,
                self.patch_size * self.patch_size * channels,
            ),
        )

    def get_config(self):
        config = super().get_config()
        config.update({"patch_size": self.patch_size})
        return config

#Let's display patches for a sample image
# Show one random training digit, then the grid of patches that Patches()
# cuts from it after resizing to image_size x image_size.
plt.figure(figsize=(4, 4))
image = x_train[np.random.randint(x_train.shape[0])]
# The images are single-channel: drop the trailing channel axis and use a gray
# colormap so the digit is shown as grayscale instead of a false-colour map.
plt.imshow(image[..., 0].astype("uint8"), cmap="gray")
plt.axis("off")

resized_image = ops.image.resize(
    ops.convert_to_tensor([image]), size=(image_size, image_size)
)
patches = Patches(patch_size)(resized_image)
print(f"Image size: {image_size} X {image_size}")
print(f"Patch size: {patch_size} X {patch_size}")
print(f"Patches per image: {patches.shape[1]}")
print(f"Elements per patch: {patches.shape[-1]}")

# The patches form a square n x n grid (image_size // patch_size per side).
n = int(np.sqrt(patches.shape[1]))
plt.figure(figsize=(4, 4))
for i, patch in enumerate(patches[0]):
    plt.subplot(n, n, i + 1)
    # Each single-channel patch has patch_size * patch_size values.
    patch_img = ops.reshape(patch, (patch_size, patch_size))
    plt.imshow(ops.convert_to_numpy(patch_img).astype("uint8"), cmap="gray")
    plt.axis("off")

#Implement the patch encoding layer
class PatchEncoder(layers.Layer):
    """Linearly project patches and add a learned position embedding.

    Expects ``patch`` shaped (batch, num_patches, patch_dim); returns
    (batch, num_patches, projection_dim).
    """

    def __init__(self, num_patches, projection_dim, **kwargs):
        # Forward base-layer kwargs (name, dtype, trainable, ...) so the layer
        # can be rebuilt with from_config(get_config()).
        super().__init__(**kwargs)
        self.num_patches = num_patches
        self.projection_dim = projection_dim
        self.projection = layers.Dense(units=projection_dim)
        self.position_embedding = layers.Embedding(
            input_dim=num_patches, output_dim=projection_dim
        )

    def call(self, patch):
        # Position ids 0..num_patches-1 with a leading axis of size 1, so the
        # position embedding broadcasts over the batch when added.
        positions = ops.expand_dims(
            ops.arange(start=0, stop=self.num_patches, step=1), axis=0
        )
        projected_patches = self.projection(patch)
        encoded = projected_patches + self.position_embedding(positions)
        return encoded

    def get_config(self):
        # Both constructor arguments are required to rebuild the layer.
        config = super().get_config()
        config.update(
            {
                "num_patches": self.num_patches,
                "projection_dim": self.projection_dim,
            }
        )
        return config

#Build the ViT model
def create_vit_classifier(ATTENTION_ARCH=StandardMultiHeadAttention, num_of_heads=1):
    """Build a Vision Transformer classifier with a pluggable attention layer.

    Args:
        ATTENTION_ARCH: Attention layer class. It is constructed with
            ``num_heads``, ``key_dim`` and ``dropout`` keyword arguments and
            called with the normalised tokens passed twice (self-attention).
        num_of_heads: Number of attention heads. Each head gets
            ``projection_dim // num_of_heads`` key dimensions.

    Returns:
        An uncompiled ``keras.Model`` that maps images of shape
        ``input_shape`` to ``num_classes`` unnormalised logits.
    """
    inputs = keras.Input(shape=input_shape)
    # Normalise, resize and augment the raw images.
    augmented = data_augmentation(inputs)
    # Cut the images into flattened patches and embed them as tokens.
    patches = Patches(patch_size)(augmented)
    encoded_patches = PatchEncoder(num_patches, projection_dim)(patches)
    key_dim = projection_dim // num_of_heads

    for _ in range(transformer_layers):
        # Pre-norm self-attention sub-block with a residual connection.
        x1 = layers.LayerNormalization(epsilon=1e-6)(encoded_patches)
        attention_output = ATTENTION_ARCH(
            num_heads=num_of_heads, key_dim=key_dim, dropout=0.1
        )(x1, x1)
        x2 = layers.Add()([attention_output, encoded_patches])
        # Pre-norm MLP sub-block with a residual connection. The MLP's last
        # width must equal projection_dim for the Add to be valid.
        x3 = layers.LayerNormalization(epsilon=1e-6)(x2)
        x3 = mlp(x3, hidden_units=transformer_units, dropout_rate=0.1)
        encoded_patches = layers.Add()([x3, x2])

    # Normalise, then flatten all tokens into a single
    # (batch, num_patches * projection_dim) feature vector.
    representation = layers.LayerNormalization(epsilon=1e-6)(encoded_patches)
    representation = layers.Flatten()(representation)
    representation = layers.Dropout(0.2)(representation)
    # Classification head: MLP followed by a linear layer producing logits.
    features = mlp(representation, hidden_units=mlp_head_units, dropout_rate=0.2)
    logits = layers.Dense(num_classes)(features)
    return keras.Model(inputs=inputs, outputs=logits)

# Learning-rate schedule used by the LearningRateScheduler callback.
# Constant per-epoch decay factor, chosen so that
#     max_learning_rate * lr_step ** num_epochs == min_learning_rate
lr_step = (min_learning_rate / max_learning_rate) ** (1 / num_epochs)


def lr_scheduler(epoch, lr):
    """Return the learning rate to use for ``epoch`` (0-based).

    Epoch ``e`` trains with ``max_learning_rate * lr_step ** (e + 1)``, so the
    last epoch (``num_epochs - 1``) runs at ``min_learning_rate``. This gives
    the same per-epoch values as repeatedly multiplying the optimiser's rate
    by ``lr_step``. Because it is computed from ``epoch`` alone, it does not
    depend on the optimiser's current state (resuming with ``initial_epoch``
    or calling ``fit`` again stays on schedule) and does not accumulate
    float32 rounding.

    Args:
        epoch: Index of the epoch about to start.
        lr: Current optimiser learning rate supplied by Keras; unused.

    Returns:
        The learning rate as a Python float.
    """
    return max_learning_rate * lr_step ** (epoch + 1)


#Compile, train, and evaluate the model
def _attention_param_count(model, arch_name):
    """Return the parameter count of the first top-level layer of class ``arch_name``.

    Falls back to ``model.layers[5]`` (the previously hard-coded index,
    presumably the first attention layer in create_vit_classifier's graph)
    when no layer's class is named ``arch_name``.
    """
    for layer in model.layers:
        if type(layer).__name__ == arch_name:
            return layer.count_params()
    return model.layers[5].count_params()


def run_experiment(model, arch_name="StandardMultiHeadAttention", run_number=1, num_heads=1):
    """Compile, train and evaluate ``model``, then write its results to disk.

    Training uses AdamW with the per-epoch ``lr_scheduler``. 10% of the
    training data is held out for validation. The weights with the best
    validation accuracy are checkpointed and reloaded before evaluating on
    the test set.

    Files go into a ``<arch_name>_<num_heads>_heads/run_num_<run_number>/``
    sub-directory of ``./results/mnist/model`` (``model.weights.h5``) and of
    ``./results/mnist/history`` (tab-separated ``history.csv``,
    ``test_history.csv`` and ``general_info.csv``).

    Args:
        model: Uncompiled Keras model to train.
        arch_name: Attention architecture name, used in result paths and to
            locate the attention layer when counting its parameters.
        run_number: Repetition index, used in result paths.
        num_heads: Number of attention heads, used in result paths.

    Returns:
        The ``History`` object returned by ``model.fit``.
    """
    optimizer = keras.optimizers.AdamW(
        learning_rate=max_learning_rate, weight_decay=weight_decay
    )
    model.compile(
        optimizer=optimizer,
        loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
        metrics=[
            keras.metrics.SparseCategoricalAccuracy(name="accuracy"),
            keras.metrics.SparseTopKCategoricalAccuracy(5, name="top-5-accuracy"),
        ],
    )

    # Results layout: <root>/<arch>_<heads>_heads/run_num_<run>/<file>
    run_subdir = os.path.join(f"{arch_name}_{num_heads}_heads", f"run_num_{run_number}")
    model_dir = os.path.join("./results/mnist/model", run_subdir)
    history_dir = os.path.join("./results/mnist/history", run_subdir)
    checkpoint_filepath = os.path.join(model_dir, "model.weights.h5")
    history_filepath = os.path.join(history_dir, "history.csv")
    test_history_filepath = os.path.join(history_dir, "test_history.csv")
    general_info_filepath = os.path.join(history_dir, "general_info.csv")
    os.makedirs(model_dir, exist_ok=True)
    os.makedirs(history_dir, exist_ok=True)

    # Keep only the weights from the epoch with the best validation accuracy.
    checkpoint_callback = keras.callbacks.ModelCheckpoint(
        checkpoint_filepath,
        monitor="val_accuracy",
        save_best_only=True,
        save_weights_only=True,
    )
    time_callback = TimeHistory()
    lr_scheduler_callback = keras.callbacks.LearningRateScheduler(lr_scheduler, verbose=1)

    history = model.fit(
        x=x_train,
        y=y_train,
        batch_size=batch_size,
        epochs=num_epochs,
        validation_split=0.1,
        callbacks=[checkpoint_callback, time_callback, lr_scheduler_callback],
    )

    # Per-epoch training/validation metrics, rounded to 5 decimal places.
    history_df = pd.DataFrame(history.history).round(5)
    history_df.to_csv(history_filepath, sep="\t", index=False)

    # Evaluate the best checkpoint rather than the final-epoch weights.
    model.load_weights(checkpoint_filepath)
    loss, accuracy, top_5_accuracy = model.evaluate(x_test, y_test)
    test_history_df = pd.DataFrame(
        [[round(loss, 5), round(accuracy, 5), round(top_5_accuracy, 5)]],
        columns=["loss", "accuracy", "top-5-accuracy"],
    )
    test_history_df.to_csv(test_history_filepath, sep="\t", index=False)

    num_of_attention_params = _attention_param_count(model, arch_name)

    epoch_times = time_callback.times
    average_epoch_time = sum(epoch_times) / len(epoch_times)
    # The first epoch is reported separately, presumably because it includes
    # one-off warm-up/compilation cost. With a single epoch there is nothing
    # left to average, so record NaN instead of dividing by zero.
    if len(epoch_times) > 1:
        average_epoch_time_without_first_epoch = sum(epoch_times[1:]) / (len(epoch_times) - 1)
    else:
        average_epoch_time_without_first_epoch = float("nan")

    # One header row plus one value row; timings rounded to 3 decimal places.
    general_info_pd = pd.DataFrame(
        [[
            num_of_attention_params,
            round(average_epoch_time_without_first_epoch, 3),
            round(average_epoch_time, 3),
        ]],
        columns=[
            "num_of_attention_params",
            "average_epoch_time_excluding_first_epoch",
            "average_epoch_time",
        ],
    )
    general_info_pd.to_csv(general_info_filepath, sep="\t", index=False)

    return history



def plot_history(item, history, model_name="StandardMultiHeadAttention", run_number=0):
    """Plot the training and validation curves of one metric from a fit history.

    Args:
        item: Metric key in ``history.history`` (e.g. "loss"); the key
            ``"val_" + item`` must also be present.
        history: ``History`` object returned by ``model.fit``.
        model_name: Architecture name shown in the title.
        run_number: Run index shown in the title.
    """
    title = "Train and Validation {} for {} in run {}".format(item, model_name, run_number)
    # Draw on a fresh figure. Otherwise pyplot draws on whatever figure is
    # current (e.g. the patch grid created earlier in this script, or the
    # previous metric's plot when plt.show() does not close it).
    fig = plt.figure()
    plt.plot(history.history[item], label=item)
    plt.plot(history.history["val_" + item], label="val_" + item)
    plt.xlabel("Epochs")
    plt.ylabel(item)
    plt.title(title, fontsize=14)
    plt.legend()
    plt.grid()
    plt.show()
    # Release the figure so repeated calls do not accumulate open figures.
    plt.close(fig)


# Attention architectures and head counts to benchmark. Each combination is
# trained NUM_OF_RUNS times; the run index is the outermost loop, so every
# configuration is trained once before the next repetition starts.
ATTENTION_ARCHS = [
    StandardMultiHeadAttention,
    OptimisedAttention,
    EfficientAttention,
    SuperAttention,
]
NUM_OF_HEADS = [4, 2, 1]
NUM_OF_RUNS = 5

for run_number in range(NUM_OF_RUNS):
    for attention_arch in ATTENTION_ARCHS:
        arch_name = attention_arch.__name__
        for num_of_heads in NUM_OF_HEADS:
            vit_classifier = create_vit_classifier(attention_arch, num_of_heads)
            history = run_experiment(vit_classifier, arch_name, run_number, num_of_heads)
            # One train-vs-validation plot per tracked metric.
            for metric in ("loss", "accuracy", "top-5-accuracy"):
                plot_history(metric, history, arch_name, run_number)
